| Row | age | male | rate | true_t | t | survived | idx |
|---|---|---|---|---|---|---|---|
| | Int64 | Bool | Float64 | Float64 | Float64 | Bool | Int64 |
| 1 | 51 | false | 0.0497871 | 20.672 | 20.0 | true | 1 |
| 2 | 51 | false | 0.0497871 | 24.6577 | 20.0 | true | 2 |
| 3 | 50 | false | 0.0497871 | 31.853 | 20.0 | true | 3 |
| 4 | 55 | false | 0.0497871 | 9.7404 | 9.7404 | false | 4 |
| 5 | 46 | false | 0.0497871 | 15.3396 | 15.3396 | false | 5 |
| 6 | 45 | false | 0.0497871 | 41.6487 | 20.0 | true | 6 |
| 7 | 56 | false | 0.0497871 | 12.7356 | 12.7356 | false | 7 |
| 8 | 47 | true | 0.082085 | 0.761362 | 0.761362 | false | 8 |
| 9 | 61 | false | 0.0497871 | 6.27149 | 6.27149 | false | 9 |
| 10 | 57 | false | 0.0497871 | 16.5202 | 16.5202 | false | 10 |
| 11 | 54 | true | 0.082085 | 17.3104 | 17.3104 | false | 11 |
| 12 | 49 | false | 0.0497871 | 0.787504 | 0.787504 | false | 12 |
| 13 | 56 | true | 0.082085 | 2.42708 | 2.42708 | false | 13 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 89 | 59 | false | 0.0497871 | 21.8631 | 20.0 | true | 89 |
| 90 | 50 | true | 0.082085 | 3.82811 | 3.82811 | false | 90 |
| 91 | 55 | true | 0.082085 | 12.3331 | 12.3331 | false | 91 |
| 92 | 56 | true | 0.082085 | 29.3846 | 20.0 | true | 92 |
| 93 | 47 | true | 0.082085 | 18.4165 | 18.4165 | false | 93 |
| 94 | 51 | false | 0.0497871 | 25.4902 | 20.0 | true | 94 |
| 95 | 65 | false | 0.0497871 | 21.1051 | 20.0 | true | 95 |
| 96 | 67 | true | 0.082085 | 19.4895 | 19.4895 | false | 96 |
| 97 | 62 | false | 0.0497871 | 8.24189 | 8.24189 | false | 97 |
| 98 | 65 | false | 0.0497871 | 31.491 | 20.0 | true | 98 |
| 99 | 53 | true | 0.082085 | 6.39705 | 6.39705 | false | 99 |
| 100 | 67 | true | 0.082085 | 6.19111 | 6.19111 | false | 100 |
# Reproducing example models from survivalstan
## Survival analysis - what’s that?
According to Wikipedia:
> Survival analysis is a branch of statistics for analyzing the expected duration of time until one event occurs, such as death in biological organisms and failure in mechanical systems.
We’ll consider the setting used for the examples at https://jburos.github.io/survivalstan/Examples.html. We will have a model which, for a set of persons, takes
- a set of covariates (age and gender per person),
- a list of times at which the event either occurs, or until which the event did not occur (one time and event/survival indicator per person),
and, after following standard Bayesian procedures via conditioning on observations, yields a way to predict the survival time of unobserved persons, given the same covariates.
For fixed covariates \(x\) and model parameters \(\theta\), the models below will give us a way to compute a (piecewise exponential) survival function \(S(t) = Pr(T > t)\), i.e. a function which models the probability that the event in question has not occurred until the specified time \(t\). Usually, and in particular in our setting, the survival function is the solution of a simple linear first-order differential equation with variable coefficients; concretely, we have
\[ S'(t) = -\lambda(t)S(t)\quad\text{and}\quad S(0) = 1 \] where the hazard function/rate \(\lambda(t)\) is a non-negative function, such that \(S(t)\) is monotonically non-increasing and has values in \((0, 1]\). The log of the survival function is then \[ \log S(t) = -\int_0^t\lambda(\tau) d\tau. \]
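As a quick sanity check on this relationship, the following sketch (in Python rather than the notebook’s Julia, purely for illustration) integrates the ODE with forward Euler for a constant hazard and compares the result to the closed-form exponential survival function \(S(t) = \exp(-\lambda t)\):

```python
import math

# Forward-Euler integration of S'(t) = -lambda_ * S(t), S(0) = 1,
# for a constant hazard lambda_. The closed-form solution is
# S(t) = exp(-lambda_ * t), i.e. log S(t) = -lambda_ * t.
lambda_ = 0.05
t_end, n_steps = 20.0, 200_000
dt = t_end / n_steps

s = 1.0
for _ in range(n_steps):
    s -= lambda_ * s * dt

print(s, math.exp(-lambda_ * t_end))  # both approximately 0.3679
```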
As \(S(t)\) models survival (the non-occurrence of the event), the log likelihood of the event occurring at a given time \(t\) is \[ \log p_1(t) = \log(-S'(t)) = \log \lambda(t) + \log S(t) = \log \lambda(t) -\int_0^t\lambda(\tau) d\tau \] and the log likelihood of survival up to at least time \(t\) is \[ \log p_0(t) = \log S(t) = -\int_0^t\lambda(\tau) d\tau. \]
The first term (\(p_1(t)\)) has to be used for the likelihood contribution of observations where the event occurs (survival up to exactly time \(t\)), while the second term (\(p_0(t)\)) has to be used for the likelihood contribution of observations where the event does not occur until the end of the observation time, a.k.a. censored observations.
If the hazard function \(\lambda(\tau)\) is constant and if we do not care about constant terms (as e.g. during MCMC) we can use the Poisson distribution to compute the appropriate terms “automatically”. For piecewise constant hazard functions, it’s possible to chain individual Poisson likelihoods to compute the overall likelihood (modulo a constant term).
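For a single interval of width \(dt\) and constant hazard \(\lambda\), this equivalence is easy to verify numerically. A small Python sketch (illustrative only; the values are arbitrary):

```python
import math

def poisson_log_pmf(k, mu):
    # log of the Poisson pmf: log(mu^k * exp(-mu) / k!)
    return k * math.log(mu) - mu - math.lgamma(k + 1)

lam, dt = 0.3, 2.5  # arbitrary constant hazard and interval width

# Censored observation: log p0 = -lam * dt is exactly the
# Poisson(lam * dt) log-pmf at 0 (no event observed).
log_p0 = -lam * dt
assert math.isclose(log_p0, poisson_log_pmf(0, lam * dt))

# Event observation: log p1 = log(lam) - lam * dt matches the
# Poisson(lam * dt) log-pmf at 1 up to the constant log(dt),
# which depends only on the data.
log_p1 = math.log(lam) - lam * dt
assert math.isclose(log_p1, poisson_log_pmf(1, lam * dt) - math.log(dt))
```

This is exactly why the Stan models can write `event ~ poisson_log(log_hazard)` even though the data are not Poisson counts.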
For piecewise constant hazard functions of the form \[ \lambda(t) = \begin{cases} \lambda_1 & \text{if } t \in [t_0, t_1],\\ \lambda_2 & \text{if } t \in (t_1, t_2],\\ \dots \end{cases} \] with \(0 = t_0 < t_1 < t_2 < \dots\) the survival function can be directly computed as \[ \log S(t_j) = -\sum_{i=1}^j (t_i-t_{i-1}) \lambda_i. \]
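The following Python sketch (with made-up slab boundaries and hazard values) checks the closed-form sum against a brute-force Riemann sum of the hazard:

```python
# Piecewise-constant hazard: lam[i-1] on the slab (t[i-1], t[i]].
t = [0.0, 1.0, 2.5, 4.0]   # slab boundaries t_0 < t_1 < t_2 < t_3
lam = [0.1, 0.4, 0.2]      # hazard value on each slab

# Closed form: log S(t_3) = -sum_i (t_i - t_{i-1}) * lambda_i.
log_S = -sum((t[i] - t[i - 1]) * lam[i - 1] for i in range(1, len(t)))

# The same quantity via a midpoint Riemann sum of -integral lambda(tau) dtau.
def hazard(tau):
    for i in range(1, len(t)):
        if tau <= t[i]:
            return lam[i - 1]
    return lam[-1]

n = 100_000
dtau = t[-1] / n
integral = sum(hazard((k + 0.5) * dtau) * dtau for k in range(n))

print(log_S, -integral)  # both approximately -1.0
```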
## Why do this? Why reimplement things?
Out of curiosity, to figure out whether I could, and to demonstrate that I understand survival analysis. Writing down the math is nice and all, but to get correct simulation results, every little detail has to be right. At least in principle; in practice the simulation can still be subtly wrong due to errors which don’t crash everything, but only e.g. introduce biases.
## Simulation
### Simulated data
To simulate the data, we generate (for 100 persons)
- the `age` from a Poisson distribution with mean 55,
- the gender (`male` or not) from a Bernoulli distribution with mean 1/2,
- assume a constant (in time) hazard function, computed from `age` and `male` as `log(hazard) = -3 + .5 * male`,
- draw true survival times `true_t` from an Exponential distribution with rate parameter `hazard`,
- cap them at a `censor_time` of 20, i.e. `t = min(true_t, censor_time)`, and
- set `survived` to `true` if `true_t > censor_time` and `false` otherwise.
Currently, the formula/`rate_form` used is hardcoded to match the examples.
sim_data_exp_correlated(rng=Random.default_rng(); N, censor_time, rate_form, rate_coefs) = begin
idx = 1:N
age = rand(rng, Poisson(55), N)
male = rand(rng, Bernoulli(.5), N)
rate = @. exp(rate_coefs[1] + male * rate_coefs[2])
true_t = rand.(rng, ConstantExponentialModel.(rate))
t = min.(true_t, censor_time)
survived = true_t .> censor_time
DataFrame((;age, male, rate, true_t, t, survived, idx))
end

For all simulations,
- we model the hazard function \(\lambda_i(t)\) of person \(i = 1,\dots,100\) as piecewise constant, with as many pieces as there are unique event times, plus a final one which extends from the largest observed event time to the censoring time,
- every person’s hazard function is unique (provided the covariates are unique),
- the personwise (\(i\)) and timeslabwise (\(j\)) hazard values will be of the form \[ \log\lambda_{i,j} = \log a + \log\kappa_j + \langle{}X_i,\beta_j\rangle{}, \] where \(\log a\) is a scalar intercept, \(\log\kappa_j\) is a time-varying (but person-constant) effect, \(X_i\) are the \(i\)-th person’s covariates, and \(\beta_j\) are the potentially time-varying covariate effects (in timeslab \(j\)). For the first two models, \(\beta\) will be constant, while it will vary for the last model.
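In plain Python with made-up numbers (two persons, three timeslabs, covariates age and male; all values hypothetical), the resulting log-hazard matrix would look like:

```python
# Hypothetical tiny instance of log lambda_{i,j} = log a + log kappa_j + <X_i, beta_j>.
log_a = -3.0                       # scalar intercept (log a)
log_kappa = [0.1, -0.2, 0.05]      # time-varying, person-constant effects (3 slabs)
X = [[51.0, 0.0], [47.0, 1.0]]     # per-person covariates: (age, male)
beta = [[0.0, 0.5]] * 3            # covariate effects per slab; constant in time here

log_lam = [
    [log_a + log_kappa[j] + sum(x * b for x, b in zip(X[i], beta[j]))
     for j in range(len(log_kappa))]
    for i in range(len(X))
]
# e.g. person 2 (male) in slab 1: -3 + 0.1 + 0.5 = -2.4
print(log_lam[1][0])
```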
## pem_survival_model
The easiest model. The covariate effects are constant (\(\beta_1=\beta_2=\dots\)) and the time-varying (but person-constant) effect \(\log\kappa_j\) has a hierarchical normal prior with mean 0 and unknown scale (with a standard half-normal prior). There seems to be a small mistake in the original model, where at line 42 (AFAICT) log_t_dur = log(t_obs) assigns the logarithm of the event time to the variable which is supposed to contain the logarithm of the timeslab width.
function pem_survival_model(;
survived,
t,
design_matrix,
likelihood=true
)
(;
n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
) = prepare_survival(;t, design_matrix)
StanBlocks.@stan begin
@parameters begin
log_hazard_intercept::real
beta::vector[n_covariates]
log_hazard_timewise_scale::real(lower=0)
log_hazard_timewise::vector[n_timepoints]
end
log_hazard_personwise = design_matrix*beta
StanBlocks.@model @views begin
log_hazard_intercept ~ normal(0, 1)
beta ~ cauchy(0, 2)
log_hazard_timewise_scale ~ normal(0, 1)
log_hazard_timewise ~ normal(0, log_hazard_timewise_scale)
log_lik = Base.broadcast(1:n_persons) do person
idxs = 1:end_idxs[person]
survival_lpdf(
survived[person],
StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[idxs]),
log_dts[idxs]
)
end
likelihood && (target += sum(log_lik))
end
StanBlocks.@generated_quantities begin
log_lik = collect(log_lik)
t_pred = map(1:n_persons) do person
for timepoint in 1:n_timepoints
log_hazard = log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[timepoint]
rv = rand(Exponential(exp(-log_hazard)))
rv <= dt[timepoint] && return t0[timepoint] + rv
end
t1[end]
end
end
end
end

/* Variable naming:
// dimensions
N = total number of observations (length of data)
S = number of sample ids
T = max timepoint (number of timepoint ids)
M = number of covariates
// main data matrix (per observed timepoint*record)
s = sample id for each obs
t = timepoint id for each obs
event = integer indicating if there was an event at time t for sample s
x = matrix of real-valued covariates at time t for sample n [N, X]
// timepoint-specific data (per timepoint, ordered by timepoint id)
t_obs = observed time since origin for each timepoint id (end of period)
t_dur = duration of each timepoint period (first diff of t_obs)
*/
// Jacqueline Buros Novik <jackinovik@gmail.com>
data {
// dimensions
int<lower=1> N;
int<lower=1> S;
int<lower=1> T;
int<lower=0> M;
// data matrix
int<lower=1, upper=N> s[N]; // sample id
int<lower=1, upper=T> t[N]; // timepoint id
int<lower=0, upper=1> event[N]; // 1: event, 0:censor
matrix[N, M] x; // explanatory vars
// timepoint data
vector<lower=0>[T] t_obs;
vector<lower=0>[T] t_dur;
}
transformed data {
vector[T] log_t_dur; // log-duration for each timepoint
int n_trans[S, T];
log_t_dur = log(t_obs);
// n_trans used to map each sample*timepoint to n (used in gen quantities)
// map each patient/timepoint combination to n values
for (n in 1:N) {
n_trans[s[n], t[n]] = n;
}
// fill in missing values with n for max t for that patient
// ie assume "last observed" state applies forward (may be problematic for TVC)
// this allows us to predict failure times >= observed survival times
for (samp in 1:S) {
int last_value;
last_value = 0;
for (tp in 1:T) {
// manual says ints are initialized to neg values
// so <=0 is a shorthand for "unassigned"
if (n_trans[samp, tp] <= 0 && last_value != 0) {
n_trans[samp, tp] = last_value;
} else {
last_value = n_trans[samp, tp];
}
}
}
}
parameters {
vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
vector[M] beta; // beta for each covariate
real<lower=0> baseline_sigma;
real log_baseline_mu;
}
transformed parameters {
vector[N] log_hazard;
vector[T] log_baseline; // unstructured baseline hazard for each timepoint t
log_baseline = log_baseline_mu + log_baseline_raw + log_t_dur;
for (n in 1:N) {
log_hazard[n] = log_baseline[t[n]] + x[n,]*beta;
}
}
model {
beta ~ cauchy(0, 2);
event ~ poisson_log(log_hazard);
log_baseline_mu ~ normal(0, 1);
baseline_sigma ~ normal(0, 1);
log_baseline_raw ~ normal(0, baseline_sigma);
}
generated quantities {
real log_lik[N];
vector[T] baseline;
real y_hat_time[S]; // predicted failure time for each sample
int y_hat_event[S]; // predicted event (0:censor, 1:event)
// compute raw baseline hazard, for summary/plotting
baseline = exp(log_baseline_mu + log_baseline_raw);
// prepare log_lik for loo-psis
for (n in 1:N) {
log_lik[n] = poisson_log_log(event[n], log_hazard[n]);
}
// posterior predicted values
for (samp in 1:S) {
int sample_alive;
sample_alive = 1;
for (tp in 1:T) {
if (sample_alive == 1) {
int n;
int pred_y;
real log_haz;
// determine predicted value of this sample's hazard
n = n_trans[samp, tp];
log_haz = log_baseline[tp] + x[n,] * beta;
// now, make posterior prediction of an event at this tp
if (log_haz < log(pow(2, 30)))
pred_y = poisson_log_rng(log_haz);
else
pred_y = 9;
// summarize survival time (observed) for this pt
if (pred_y >= 1) {
// mark this patient as ineligible for future tps
// note: deliberately treat 9s as events
sample_alive = 0;
y_hat_time[samp] = t_obs[tp];
y_hat_event[samp] = 1;
}
}
} // end per-timepoint loop
// if patient still alive at max
if (sample_alive == 1) {
y_hat_time[samp] = t_obs[T];
y_hat_event[samp] = 0;
}
} // end per-sample loop
}

## pem_survival_model_randomwalk
Identical to the first model, except that the time-varying (but person-constant) effect \(\log\kappa_j\) gets a “random walk” prior. AFAICT, the original model has the same small mistake as the first one (this time at line 43), but IMO some other (minor) things go “wrong” in the construction of the “random walk” prior. Or rather, I believe that instead of the random walk prior as implemented in the original code, an approximate Brownian motion / Wiener process prior would have been a better choice:
A random walk prior as implemented in the original code implies different priors for different numbers of persons, and also for different realizations of the event times, while an approximate Wiener process prior does not (or rather, much less so). Consider the following:
### (Gaussian) random walk prior
For random walk parameters \(x_1, x_2, \dots\) with scale parameter \(\sigma\), the (conditional) prior density is \[ p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, \sigma^2) \text{ for } i=1,2,\dots \] and with \(x_0\) another parameter with appropriate prior.
### Approximate (Gaussian) Wiener process prior
Following Wikipedia:
> The Wiener process \(W_t\) is characterised by the following properties: […] W has Gaussian increments: […] \(W_{t+u} - W_t \sim \mathcal{N}(0,u)\).
I.e., for timepoints \(0 = t_0 < t_1 < t_2 < \dots\) as above, the (conditional) prior density of the (shifted) Wiener process values \(x_1, x_2, \dots\) with scale parameter \(\sigma\) is \[ p(x_i | x_{i-1}) = p_\mathcal{N}(x_i | x_{i-1}, (t_i-t_{i-1})\sigma^2) \text{ for } i=1,2,\dots \] and with \(x_0\) as before.
### Dependence on the observed event times
The difference between the two priors becomes most easily apparent by looking at the implied prior on the (log) hazard at (or right before) the censor time \(t_\text{censor} = t_{N+1}\), for varying numbers of unique observed event times \(N\). For the random walk prior, we’ll have \[ x_j \sim \mathcal{N}(x_0, j\sigma^2) \text{ for } j = 1,\dots,N+1, \] while for the Wiener process prior, we’ll have \[ x_j \sim \mathcal{N}(x_0, t_j\sigma^2) \text{ for } j = 1,\dots,N+1. \] In particular, for \(j=N+1\) (i.e. at the censor time), we get a constant prior distribution under the Wiener process prior, but under the random walk prior we get a prior distribution that depends on the number of unique observed event times \(N\). Similarly, even for fixed \(N\), there is a (potentially strong) dependence of the implied prior for “interior” time slabs on the realization of the event times under the random walk prior, while there is “no” such dependence under the Wiener process prior. Caveat: there will actually be a dependence of the implied prior on the event time realizations for the Wiener process as well, but this is only due to the piecewise-constant “assumption” and can be interpreted as an approximation error relative to the solution of the underlying stochastic differential equation.
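A quick Monte Carlo illustration of this difference in Python (illustrative only; \(\sigma = 1\), \(x_0\) fixed to 0, event times drawn uniformly on \((0, t_\text{censor})\)): under the random walk prior the variance of the last state grows with \(N\), while under the Wiener process prior it stays near \(t_\text{censor}\sigma^2\).

```python
import math
import random

sigma, t_censor = 1.0, 20.0
rng = random.Random(0)

def last_state(N, wiener):
    # Draw N event times, append the censor time, and run the chain
    # x_j = x_{j-1} + noise from x_0 = 0 up to the censor time.
    ts = sorted(rng.uniform(0.0, t_censor) for _ in range(N)) + [t_censor]
    x, t_prev = 0.0, 0.0
    for t in ts:
        # Wiener: sd scales with sqrt of the slab width; random walk: constant sd.
        sd = sigma * math.sqrt(t - t_prev) if wiener else sigma
        x += rng.gauss(0.0, sd)
        t_prev = t
    return x

def mc_var(samples):
    return sum(x * x for x in samples) / len(samples)

for N in (10, 100):
    v_rw = mc_var([last_state(N, wiener=False) for _ in range(20_000)])
    v_wp = mc_var([last_state(N, wiener=True) for _ in range(20_000)])
    # Random walk: Var ≈ (N + 1) * sigma^2; Wiener: Var ≈ t_censor * sigma^2 = 20.
    print(N, v_rw, v_wp)
```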
function pem_survival_model_randomwalk(;
survived,
t,
design_matrix,
likelihood=true
)
(;
n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
) = prepare_survival(;t, design_matrix)
rw_sqrt_scale = @. sqrt(.5*(dt[1:end-1] + dt[2:end]))
StanBlocks.@stan begin
@parameters begin
log_hazard_intercept::real
beta::vector[n_covariates]
log_hazard_timewise_scale::real(lower=0)
log_hazard_timewise::vector[n_timepoints]
end
log_hazard_personwise = design_matrix*beta
StanBlocks.@model @views begin
log_hazard_intercept ~ normal(0, 1)
beta ~ cauchy(0, 2)
log_hazard_timewise_scale ~ normal(0, 1)
log_hazard_timewise[1] ~ normal(0, 1)
log_hazard_timewise ~ random_walk(
StanBlocks.@broadcasted(log_hazard_timewise_scale * rw_sqrt_scale)
)
log_lik = Base.broadcast(1:n_persons) do person
idxs = 1:end_idxs[person]
survival_lpdf(
survived[person],
StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[idxs]),
log_dts[idxs]
)
end
likelihood && (target += sum(log_lik))
end
StanBlocks.@generated_quantities begin
log_lik = collect(log_lik)
t_pred = map(1:n_persons) do person
for timepoint in 1:n_timepoints
log_hazard = log_hazard_intercept + log_hazard_personwise[person] + log_hazard_timewise[timepoint]
rv = rand(Exponential(exp(-log_hazard)))
rv <= dt[timepoint] && return t0[timepoint] + rv
end
t1[end]
end
end
end
end

/* Variable naming:
// dimensions
N = total number of observations (length of data)
S = number of sample ids
T = max timepoint (number of timepoint ids)
M = number of covariates
// main data matrix (per observed timepoint*record)
s = sample id for each obs
t = timepoint id for each obs
event = integer indicating if there was an event at time t for sample s
x = matrix of real-valued covariates at time t for sample n [N, X]
// timepoint-specific data (per timepoint, ordered by timepoint id)
t_obs = observed time since origin for each timepoint id (end of period)
t_dur = duration of each timepoint period (first diff of t_obs)
*/
// Jacqueline Buros Novik <jackinovik@gmail.com>
data {
// dimensions
int<lower=1> N;
int<lower=1> S;
int<lower=1> T;
int<lower=0> M;
// data matrix
int<lower=1, upper=N> s[N]; // sample id
int<lower=1, upper=T> t[N]; // timepoint id
int<lower=0, upper=1> event[N]; // 1: event, 0:censor
matrix[N, M] x; // explanatory vars
// timepoint data
vector<lower=0>[T] t_obs;
vector<lower=0>[T] t_dur;
}
transformed data {
vector[T] log_t_dur; // log-duration for each timepoint
int n_trans[S, T];
log_t_dur = log(t_obs);
// n_trans used to map each sample*timepoint to n (used in gen quantities)
// map each patient/timepoint combination to n values
for (n in 1:N) {
n_trans[s[n], t[n]] = n;
}
// fill in missing values with n for max t for that patient
// ie assume "last observed" state applies forward (may be problematic for TVC)
// this allows us to predict failure times >= observed survival times
for (samp in 1:S) {
int last_value;
last_value = 0;
for (tp in 1:T) {
// manual says ints are initialized to neg values
// so <=0 is a shorthand for "unassigned"
if (n_trans[samp, tp] <= 0 && last_value != 0) {
n_trans[samp, tp] = last_value;
} else {
last_value = n_trans[samp, tp];
}
}
}
}
parameters {
vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
vector[M] beta; // beta for each covariate
real<lower=0> baseline_sigma;
real log_baseline_mu;
}
transformed parameters {
vector[N] log_hazard;
vector[T] log_baseline;
log_baseline = log_baseline_raw + log_t_dur;
for (n in 1:N) {
log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + x[n,]*beta;
}
}
model {
beta ~ cauchy(0, 2);
event ~ poisson_log(log_hazard);
log_baseline_mu ~ normal(0, 1);
baseline_sigma ~ normal(0, 1);
log_baseline_raw[1] ~ normal(0, 1);
for (i in 2:T) {
log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
}
}
generated quantities {
real log_lik[N];
vector[T] baseline;
int y_hat_mat[S, T]; // ppcheck for each S*T combination
real y_hat_time[S]; // predicted failure time for each sample
int y_hat_event[S]; // predicted event (0:censor, 1:event)
// compute raw baseline hazard, for summary/plotting
baseline = exp(log_baseline_raw);
for (n in 1:N) {
log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);
}
// posterior predicted values
for (samp in 1:S) {
int sample_alive;
sample_alive = 1;
for (tp in 1:T) {
if (sample_alive == 1) {
int n;
int pred_y;
real log_haz;
// determine predicted value of y
// (need to recalc so that carried-forward data use sim tp and not t[n])
n = n_trans[samp, tp];
log_haz = log_baseline_mu + log_baseline[tp] + x[n,]*beta;
if (log_haz < log(pow(2, 30)))
pred_y = poisson_log_rng(log_haz);
else
pred_y = 9;
// mark this patient as ineligible for future tps
// note: deliberately make 9s ineligible
if (pred_y >= 1) {
sample_alive = 0;
y_hat_time[samp] = t_obs[tp];
y_hat_event[samp] = 1;
}
// save predicted value of y to matrix
y_hat_mat[samp, tp] = pred_y;
}
else if (sample_alive == 0) {
y_hat_mat[samp, tp] = 9;
}
} // end per-timepoint loop
// if patient still alive at max
//
if (sample_alive == 1) {
y_hat_time[samp] = t_obs[T];
y_hat_event[samp] = 0;
}
} // end per-sample loop
}

## pem_survival_model_timevarying
To be finished. To keep things short:
- The original model has the same minor problems as the other models.
- While the original model implements a random walk prior on the increments of the covariate effects, I’ve kept things a bit simpler and instead just implemented the corresponding Wiener process prior on the values of the covariate effects. IMO, putting a given prior on the increments instead of on the values or vice versa is a modeling decision, and not a “mistake” by any stretch of the imagination. Doing one or the other implies different things, and which choice is “better” is not clear a priori and may depend on the setting.
- I believe sampling may have failed somewhat for the run included in this notebook. I have seen better sampling “runs”, but as this doesn’t have to be perfect, I’ve left it as is.
function pem_survival_model_timevarying(;
survived,
t,
design_matrix,
likelihood=true
)
(;
n_persons, n_covariates, t1, n_timepoints, end_idxs, t0, dt, log_dts
) = prepare_survival(;t, design_matrix)
rw_sqrt_scale = @. sqrt(.5*(dt[1:end-1] + dt[2:end]))
StanBlocks.@stan begin
@parameters begin
log_hazard_intercept::real
beta_timewise_scale::real(lower=0)
beta_timewise::matrix[n_covariates, n_timepoints]
log_hazard_timewise_scale::real(lower=0)
log_hazard_timewise::vector[n_timepoints]
end
log_hazard_personwise = design_matrix*beta_timewise
StanBlocks.@model @views begin
log_hazard_intercept ~ normal(0, 1)
beta_timewise_scale ~ cauchy(0, 1)
beta_timewise[:, 1] ~ cauchy(0, 1)
beta_timewise' ~ random_walk(
StanBlocks.@broadcasted(beta_timewise_scale * rw_sqrt_scale)
)
log_hazard_timewise_scale ~ normal(0, 1)
log_hazard_timewise[1] ~ normal(0, 1)
log_hazard_timewise ~ random_walk(
StanBlocks.@broadcasted(log_hazard_timewise_scale * rw_sqrt_scale)
)
log_lik = Base.broadcast(1:n_persons) do person
idxs = 1:end_idxs[person]
survival_lpdf(
survived[person],
StanBlocks.@broadcasted(log_hazard_intercept + log_hazard_personwise[person, idxs] + log_hazard_timewise[idxs]),
log_dts[idxs]
)
end
likelihood && (target += sum(log_lik))
end
StanBlocks.@generated_quantities begin
log_lik = collect(log_lik)
t_pred = map(1:n_persons) do person
for timepoint in 1:n_timepoints
log_hazard = log_hazard_intercept + log_hazard_personwise[person, timepoint] + log_hazard_timewise[timepoint]
rv = rand(Exponential(exp(-log_hazard)))
rv <= dt[timepoint] && return t0[timepoint] + rv
end
t1[end]
end
end
end
end

/* Variable naming:
// dimensions
N = total number of observations (length of data)
S = number of sample ids
T = max timepoint (number of timepoint ids)
M = number of covariates
// data
s = sample id for each obs
t = timepoint id for each obs
event = integer indicating if there was an event at time t for sample s
x = matrix of real-valued covariates at time t for sample n [N, X]
obs_t = observed end time for interval for timepoint for that obs
*/
// Jacqueline Buros Novik <jackinovik@gmail.com>
functions {
matrix spline(vector x, int N, int H, vector xi, int P) {
matrix[N, H + P] b_x; // expanded predictors
for (n in 1:N) {
for (p in 1:P) {
b_x[n,p] <- pow(x[n],p-1); // x[n]^(p-1)
}
for (h in 1:H)
b_x[n, h + P] <- fmax(0, pow(x[n] - xi[h],P-1));
}
return b_x;
}
}
data {
// dimensions
int<lower=1> N;
int<lower=1> S;
int<lower=1> T;
int<lower=0> M;
// data matrix
int<lower=1, upper=N> s[N]; // sample id
int<lower=1, upper=T> t[N]; // timepoint id
int<lower=0, upper=1> event[N]; // 1: event, 0:censor
matrix[N, M] x; // explanatory vars
// timepoint data
vector<lower=0>[T] t_obs;
vector<lower=0>[T] t_dur;
}
transformed data {
vector[T] log_t_dur;
int n_trans[S, T];
log_t_dur = log(t_obs);
// n_trans used to map each sample*timepoint to n (used in gen quantities)
// map each patient/timepoint combination to n values
for (n in 1:N) {
n_trans[s[n], t[n]] = n;
}
// fill in missing values with n for max t for that patient
// ie assume "last observed" state applies forward (may be problematic for TVC)
// this allows us to predict failure times >= observed survival times
for (samp in 1:S) {
int last_value;
last_value = 0;
for (tp in 1:T) {
// manual says ints are initialized to neg values
// so <=0 is a shorthand for "unassigned"
if (n_trans[samp, tp] <= 0 && last_value != 0) {
n_trans[samp, tp] = last_value;
} else {
last_value = n_trans[samp, tp];
}
}
}
}
parameters {
vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
real<lower=0> baseline_sigma;
real log_baseline_mu;
vector[M] beta; // beta-intercept
vector<lower=0>[M] beta_time_sigma;
vector[T-1] raw_beta_time_deltas[M]; // for each coefficient
// change in coefficient value from previous time
}
transformed parameters {
vector[N] log_hazard;
vector[T] log_baseline;
vector[T] beta_time[M];
vector[T] beta_time_deltas[M];
// adjust baseline hazard for duration of each period
log_baseline = log_baseline_raw + log_t_dur;
// compute timepoint-specific betas
// offsets from previous time
for (coef in 1:M) {
beta_time_deltas[coef][1] = 0;
for (time in 2:T) {
beta_time_deltas[coef][time] = raw_beta_time_deltas[coef][time-1];
}
}
// coefficients for each timepoint T
for (coef in 1:M) {
beta_time[coef] = beta[coef] + cumulative_sum(beta_time_deltas[coef]);
}
// compute log-hazard for each obs
for (n in 1:N) {
real log_linpred;
log_linpred <- 0;
for (coef in 1:M) {
// for now, handle each coef separately
// (to be sure we pull out the "right" beta..)
log_linpred = log_linpred + x[n, coef] * beta_time[coef][t[n]];
}
log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + log_linpred;
}
}
model {
// priors on time-varying coefficients
for (m in 1:M) {
raw_beta_time_deltas[m][1] ~ normal(0, 100);
for(i in 2:(T-1)){
raw_beta_time_deltas[m][i] ~ normal(raw_beta_time_deltas[m][i-1], beta_time_sigma[m]);
}
}
beta_time_sigma ~ cauchy(0, 1);
beta ~ cauchy(0, 1);
// priors on baseline hazard
log_baseline_mu ~ normal(0, 1);
baseline_sigma ~ normal(0, 1);
log_baseline_raw[1] ~ normal(0, 1);
for (i in 2:T) {
log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
}
// model
event ~ poisson_log(log_hazard);
}
generated quantities {
real log_lik[N];
vector[T] baseline;
int y_hat_mat[S, T]; // ppcheck for each S*T combination
real y_hat_time[S]; // predicted failure time for each sample
int y_hat_event[S]; // predicted event (0:censor, 1:event)
// compute raw baseline hazard, for summary/plotting
baseline = exp(log_baseline_raw);
// log_likelihood for loo-psis
for (n in 1:N) {
log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);
}
// posterior predicted values
for (samp in 1:S) {
int sample_alive;
sample_alive = 1;
for (tp in 1:T) {
if (sample_alive == 1) {
int n;
int pred_y;
real log_linpred;
real log_haz;
// determine predicted value of y
n = n_trans[samp, tp];
// (borrow code from above to calc linpred)
// but use sim tp not t[n]
log_linpred = 0;
for (coef in 1:M) {
// for now, handle each coef separately
// (to be sure we pull out the "right" beta..)
log_linpred = log_linpred + x[n, coef] * beta_time[coef][tp];
}
log_haz = log_baseline_mu + log_baseline[tp] + log_linpred;
// now, make posterior prediction
if (log_haz < log(pow(2, 30)))
pred_y = poisson_log_rng(log_haz);
else
pred_y = 9;
// mark this patient as ineligible for future tps
// note: deliberately make 9s ineligible
if (pred_y >= 1) {
sample_alive = 0;
y_hat_time[samp] = t_obs[tp];
y_hat_event[samp] = 1;
}
// save predicted value of y to matrix
y_hat_mat[samp, tp] = pred_y;
}
else if (sample_alive == 0) {
y_hat_mat[samp, tp] = 9;
}
} // end per-timepoint loop
// if patient still alive at max
//
if (sample_alive == 1) {
y_hat_time[samp] = t_obs[T];
y_hat_event[samp] = 0;
}
} // end per-sample loop
}

## Addendum / Disclaimer
- I am aware that survivalstan hasn’t been updated in the last 7 years (according to GitHub). I have not implemented the above models to unearth errors or to write a competitor. I believe, but haven’t checked, that the “actual” models used by survivalstan are “more” correct. I was mainly curious whether I could do it, and I wanted to see how well StanBlocks.jl does.
- I’ve skipped the `pem_survival_model_gamma` model showcased at https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html because I did not understand why the widths of the timeslabs should affect the shape parameter of the Gamma prior. Only after implementing the time-varying models did I discover the models at https://nbviewer.org/github/hammerlab/survivalstan/blob/master/example-notebooks/Test%20new_gamma_survival_model%20with%20simulated%20data.ipynb. Also, the “Worked examples” page lists a “User-supplied PEM survival model with gamma hazard”, though for some reason it does not show up in the sidebar of any of the other examples, compare https://jburos.github.io/survivalstan/examples/Example-using-pem_survival_model.html, https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_gamma%20with%20simulated%20data.html, https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_randomwalk%20with%20simulated%20data.html and https://jburos.github.io/survivalstan/examples/Test%20pem_survival_model_timevarying%20with%20simulated%20data.html.